{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "eec23018",
   "metadata": {},
   "source": [
    "# Kinematic car sensor fusion example\n",
    "RMM, 24 Feb 2022 (updated 23 Feb 2023)\n",
    "\n",
    "In this example we work through estimation of the state of a car changing\n",
    "lanes with two different sensors available: one with good longitudinal accuracy\n",
    "and the other with good lateral accuracy.\n",
    "\n",
    "All calculations are done in discrete time, using the form of the Kalman\n",
    "filter in Theorem 7.2, the predictor-corrector form, and the information filter form."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "107a6613",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy as sp\n",
    "import matplotlib.pyplot as plt\n",
    "import control as ct\n",
    "import control.optimal as opt\n",
    "import control.flatsys as fs\n",
    "\n",
    "# Define some line styles for later use\n",
    "ebarstyle = {'elinewidth': 0.5, 'capsize': 2}\n",
    "xdstyle = {'color': 'k', 'linestyle': '--', 'linewidth': 0.5, \n",
    "           'marker': '+', 'markersize': 4}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea8807a4",
   "metadata": {},
   "source": [
    "## System definition\n",
    "\n",
    "We make use of a simple model for a vehicle navigating in the plane, known as the \"bicycle model\".  The kinematics of this vehicle can be written in terms of the contact point $(x, y)$ and the angle $\\theta$ of the vehicle with respect to the horizontal axis:\n",
    "\n",
    "<table>\n",
    "<tr>\n",
    "    <td width=\"50%\"><img src=\"https://fbswiki.org/wiki/images/5/52/Kincar.png\" width=480></td>\n",
    "    <td width=\"50%\">\n",
    "$$\n",
    "\\begin{aligned}\n",
    "  \\dot x &= \\cos\\theta\\, v \\\\\n",
    "  \\dot y &= \\sin\\theta\\, v \\\\\n",
    "  \\dot\\theta &= \\frac{v}{l} \\tan \\delta\n",
    "\\end{aligned}\n",
    "$$\n",
    "    </td>\n",
    "</tr>\n",
    "</table>\n",
    "\n",
    "The input $v$ represents the velocity of the vehicle and the input $\\delta$ represents the steering angle. The parameter $l$ is the wheelbase."
   ]
  },
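  {
   "cell_type": "markdown",
   "id": "kc-sketch-01-md",
   "metadata": {},
   "source": [
    "As a quick check on the equations above, the following sketch integrates the kinematic model with forward Euler for a constant speed $v$ and a constant steering angle $\\delta$. In that case $\\dot\\theta = (v/l) \\tan\\delta$ is constant, so the heading grows linearly in time and the vehicle moves on a circle of radius $l / \\tan\\delta$. The wheelbase, speed, and steering angle are illustrative assumptions, not values read from the `kincar` module, and `bicycle_rhs` is a hypothetical helper written only for this sketch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "kc-sketch-01-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def bicycle_rhs(state, v, delta, wheelbase):\n",
    "    # Kinematic bicycle model (hypothetical helper): returns (xdot, ydot, thetadot)\n",
    "    _, _, theta = state\n",
    "    return np.array([np.cos(theta) * v,\n",
    "                     np.sin(theta) * v,\n",
    "                     v / wheelbase * np.tan(delta)])\n",
    "\n",
    "# Illustrative parameters (assumptions): wheelbase [m], speed [m/s], steering [rad]\n",
    "bike_l, bike_v, bike_delta = 3.0, 10.0, 0.1\n",
    "bike_dt, bike_T = 1e-3, 2.0\n",
    "\n",
    "# Forward-Euler integration from the origin with constant inputs\n",
    "bike_state = np.zeros(3)\n",
    "for bike_k in range(int(round(bike_T / bike_dt))):\n",
    "    bike_state = bike_state + bike_dt * bicycle_rhs(bike_state, bike_v, bike_delta, bike_l)\n",
    "\n",
    "# Heading rate is constant, so theta(T) = (v / l) tan(delta) T\n",
    "bike_theta_expected = bike_v / bike_l * np.tan(bike_delta) * bike_T\n",
    "# Turning left from the origin, the circle is centered at (0, l / tan(delta))\n",
    "bike_radius = bike_l / np.tan(bike_delta)\n",
    "bike_dist = np.hypot(bike_state[0], bike_state[1] - bike_radius)\n",
    "print(f'theta(T) = {bike_state[2]:.4f} rad (expected {bike_theta_expected:.4f} rad)')\n",
    "print(f'distance from circle center = {bike_dist:.3f} m (radius {bike_radius:.3f} m)')"
   ]
  },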
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a04106f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vehicle steering dynamics\n",
    "#\n",
    "# System state: x, y, theta\n",
    "# System input: v, delta (velocity, steering angle)\n",
    "# System output: x, y\n",
    "# System parameters: wheelbase, maxsteer\n",
    "#\n",
    "from kincar import kincar, plot_lanechange\n",
    "print(kincar)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69c048ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate a trajectory for the vehicle\n",
    "# Define the endpoints of the trajectory\n",
    "x0 = [0., -2., 0.]; u0 = [10., 0.]\n",
    "xf = [40., 2., 0.]; uf = [10., 0.]\n",
    "Tf = 4\n",
    "\n",
    "# Find a trajectory between the initial condition and the final condition\n",
    "traj = fs.point_to_point(kincar, Tf, x0, u0, xf, uf, basis=fs.PolyFamily(6))\n",
    "\n",
    "# Create the desired trajectory between the initial and final condition\n",
    "Ts = 0.1\n",
    "# Ts = 0.5\n",
    "timepts = np.arange(0, Tf + Ts, Ts)\n",
    "xd, ud = traj.eval(timepts)\n",
    "\n",
    "plot_lanechange(timepts, xd, ud)"
   ]
  },
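  {
   "cell_type": "markdown",
   "id": "kc-sketch-02-md",
   "metadata": {},
   "source": [
    "Because $x$ and $y$ are flat outputs for this model, the lane change can be planned directly in terms of $x(t)$ and $y(t)$. A minimal sketch, assuming that each output is a degree-5 polynomial (six coefficients) and that the endpoint conditions fix its value and its first and second derivatives. With zero heading and steering at both ends and (by assumption) constant speed, the conditions are $(x, \\dot x, \\ddot x) = (0, 10, 0)$ and $(40, 10, 0)$, and $(y, \\dot y, \\ddot y) = (-2, 0, 0)$ and $(2, 0, 0)$. Six conditions determine six coefficients, so each polynomial is unique: the solve returns $x(t) = 10t$ and an S-shaped $y(t)$ that crosses $y = 0$ at $t = 2$. This is not python-control's `point_to_point` implementation, whose basis scaling and solver may differ, and `quintic_through` is a hypothetical helper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "kc-sketch-02-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def quintic_through(t0, tf, start, end):\n",
    "    # Coefficients c[0..5] of p(t) = sum_i c[i] t**i such that\n",
    "    # (p, p', p'') equals `start` at t0 and `end` at tf (hypothetical helper)\n",
    "    rows, rhs = [], []\n",
    "    for t, vals in ((t0, start), (tf, end)):\n",
    "        rows.append([t**i for i in range(6)])\n",
    "        rows.append([i * t**(i - 1) if i >= 1 else 0.0 for i in range(6)])\n",
    "        rows.append([i * (i - 1) * t**(i - 2) if i >= 2 else 0.0 for i in range(6)])\n",
    "        rhs.extend(vals)\n",
    "    return np.linalg.solve(np.array(rows, dtype=float), np.array(rhs, dtype=float))\n",
    "\n",
    "# Endpoint conditions (value, first, second derivative) for the lane change\n",
    "flat_cx = quintic_through(0.0, 4.0, (0.0, 10.0, 0.0), (40.0, 10.0, 0.0))\n",
    "flat_cy = quintic_through(0.0, 4.0, (-2.0, 0.0, 0.0), (2.0, 0.0, 0.0))\n",
    "print('x(t) coefficients:', np.round(flat_cx, 6))\n",
    "print('y(t) coefficients:', np.round(flat_cy, 6))\n",
    "print('y at the midpoint t = 2:', np.polyval(flat_cy[::-1], 2.0))"
   ]
  },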
  {
   "cell_type": "markdown",
   "id": "aeeaa39e",
   "metadata": {},
   "source": [
    "### Discrete time system model\n",
    "\n",
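    "For the model that we use for the Kalman filter, we linearize the dynamics about the starting point and use a simple forward-Euler discretization based on the approximation $\\dot x \\approx (x[k+1] - x[k])/T_s$, where $T_s$ is the sampling time. For the linearized model $\\dot x = A x + B u$ this gives $x[k+1] = (I + A T_s)\\, x[k] + B T_s\\, u[k]$."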
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2469c60e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# Create a discrete-time, linear model\n",
    "#\n",
    "\n",
    "# Linearize about the starting point\n",
    "linsys = ct.linearize(kincar, x0, u0)\n",
    "\n",
    "# Create a discrete-time model by hand\n",
    "Ad = np.eye(linsys.nstates) + linsys.A * Ts\n",
    "Bd = linsys.B * Ts\n",
    "discsys = ct.ss(Ad, Bd, np.eye(linsys.nstates), 0, dt=Ts)\n",
    "print(discsys);"
   ]
  },
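  {
   "cell_type": "markdown",
   "id": "kc-sketch-03-md",
   "metadata": {},
   "source": [
    "The forward-Euler model above is only an approximation. A minimal sketch comparing it with the exact zero-order-hold discretization, which can be read off from the matrix exponential $\\exp\\left(\\begin{bmatrix} A & B \\\\ 0 & 0 \\end{bmatrix} T_s\\right) = \\begin{bmatrix} A_d & B_d \\\\ 0 & I \\end{bmatrix}$. The linearization about $\\theta = 0$, $\\delta = 0$ is written out by hand with $v = 10$ m/s and a hypothetical wheelbase $l = 3$ m (the value used by `kincar` may differ). For this $A$ we have $A^2 = 0$, so both methods give the same $A_d$, but the exact $B_d = B T_s + A B\\, T_s^2/2$ has a nonzero entry coupling the steering input to $y$ within one step, which forward Euler drops."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "kc-sketch-03-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.linalg import expm\n",
    "\n",
    "# Assumed values: speed 10 m/s (as in u0), hypothetical wheelbase 3 m, Ts = 0.1 s\n",
    "zoh_v, zoh_l, zoh_Ts = 10.0, 3.0, 0.1\n",
    "\n",
    "# Linearization of the bicycle model about theta = 0, delta = 0\n",
    "zoh_A = np.array([[0., 0., 0.],\n",
    "                  [0., 0., zoh_v],\n",
    "                  [0., 0., 0.]])\n",
    "zoh_B = np.array([[1., 0.],\n",
    "                  [0., 0.],\n",
    "                  [0., zoh_v / zoh_l]])\n",
    "\n",
    "# Forward-Euler discretization (same construction as Ad, Bd above)\n",
    "zoh_Ad_euler = np.eye(3) + zoh_A * zoh_Ts\n",
    "zoh_Bd_euler = zoh_B * zoh_Ts\n",
    "\n",
    "# Exact zero-order-hold discretization: expm([[A, B], [0, 0]] * Ts)\n",
    "zoh_M = np.zeros((5, 5))\n",
    "zoh_M[:3, :3], zoh_M[:3, 3:] = zoh_A, zoh_B\n",
    "zoh_expM = expm(zoh_M * zoh_Ts)\n",
    "zoh_Ad_zoh, zoh_Bd_zoh = zoh_expM[:3, :3], zoh_expM[:3, 3:]\n",
    "\n",
    "print('max |Ad_zoh - Ad_euler|:', np.abs(zoh_Ad_zoh - zoh_Ad_euler).max())\n",
    "print('Bd (forward Euler):')\n",
    "print(zoh_Bd_euler)\n",
    "print('Bd (zero-order hold):')\n",
    "print(zoh_Bd_zoh)"
   ]
  },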
  {
   "cell_type": "markdown",
   "id": "084c5ae8",
   "metadata": {},
   "source": [
    "### Sensor model\n",
    "\n",
    "We assume that we have two sensors: one with good longitudinal accuracy and the other with good lateral accuracy.  For each sensor we define the map from the state space to the sensor outputs, the covariance matrix for the measurements, and a white noise signal (now in discrete time).\n",
    "\n",
    "Note: we pass the keyword `dt` to the `white_noise` function so that the noise is generated as a discrete-time signal: each sample has the specified covariance, rather than being scaled by $1/\\sqrt{T_s}$ to approximate continuous-time white noise."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a19d109",
   "metadata": {},
   "outputs": [],
   "source": [
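    "# Both sensors measure the (x, y) position; they differ only in their noise levels\n",
    "\n",
    "# Sensor #1: longitudinal (accurate in x)\n",
    "C_lon = np.eye(2, discsys.nstates)\n",
    "Rw_lon = np.diag([0.1 ** 2, 1 ** 2])\n",
    "W_lon = ct.white_noise(timepts, Rw_lon, dt=Ts)\n",
    "\n",
    "# Sensor #2: lateral (accurate in y)\n",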
    "C_lat = np.eye(2, discsys.nstates)\n",
    "Rw_lat = np.diag([1 ** 2, 0.1 ** 2])\n",
    "W_lat = ct.white_noise(timepts, Rw_lat, dt=Ts)\n",
    "\n",
    "# Plot the noisy signals\n",
    "plt.subplot(2, 1, 1)\n",
    "Y = xd[0:2] + W_lon\n",
    "plt.plot(Y[0], Y[1])\n",
    "plt.plot(xd[0], xd[1], **xdstyle)\n",
    "plt.xlabel(\"$x$ position [m]\")\n",
    "plt.ylabel(\"$y$ position [m]\")\n",
    "plt.title(\"Sensor #1 (longitudinal)\")\n",
    "\n",
    "plt.subplot(2, 1, 2)\n",
    "Y = xd[0:2] + W_lat\n",
    "plt.plot(Y[0], Y[1])\n",
    "plt.plot(xd[0], xd[1], **xdstyle)\n",
    "plt.xlabel(\"$x$ position [m]\")\n",
    "plt.ylabel(\"$y$ position [m]\")\n",
    "plt.title(\"Sensor #2 (lateral)\")\n",
    "plt.tight_layout()"
   ]
  },
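  {
   "cell_type": "markdown",
   "id": "kc-sketch-04-md",
   "metadata": {},
   "source": [
    "To see what the `dt` keyword changes, here is a minimal NumPy sketch of the two scalings (not python-control's implementation). A discrete-time white noise sequence with covariance $R_w$ can be generated as $w[k] = R_w^{1/2} n[k]$ with $n[k]$ standard normal, so each sample has covariance $R_w$. A sampled approximation of continuous-time white noise with intensity $R_w$ instead divides the samples by $\\sqrt{T_s}$, giving a per-sample covariance of $R_w / T_s$. The helper `sampled_noise` is hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "kc-sketch-04-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def sampled_noise(R, nsamples, Ts=None, rng=None):\n",
    "    # Hypothetical helper. Ts=None: discrete-time noise, covariance R per sample.\n",
    "    # Otherwise: sampled continuous-time noise of intensity R, scaled by 1/sqrt(Ts).\n",
    "    rng = np.random.default_rng() if rng is None else rng\n",
    "    w = np.linalg.cholesky(R) @ rng.standard_normal((R.shape[0], nsamples))\n",
    "    return w if Ts is None else w / np.sqrt(Ts)\n",
    "\n",
    "noise_R = np.diag([0.1 ** 2, 1 ** 2])      # same values as Rw_lon\n",
    "noise_rng = np.random.default_rng(0)\n",
    "noise_disc = sampled_noise(noise_R, 200_000, rng=noise_rng)\n",
    "noise_cont = sampled_noise(noise_R, 200_000, Ts=0.1, rng=noise_rng)\n",
    "\n",
    "print('sample covariance, discrete time:')\n",
    "print(np.cov(noise_disc))\n",
    "print('sample covariance, continuous-time approximation (Ts = 0.1):')\n",
    "print(np.cov(noise_cont))"
   ]
  },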
  {
   "cell_type": "markdown",
   "id": "c3fa1a3d",
   "metadata": {},
   "source": [
    "## Linear Quadratic Estimator\n",
    "\n",
    "We now construct a linear quadratic estimator for the system using the Kalman filter form.  This is done using the [`create_estimator_iosystem`](https://github.com/python-control/python-control/blob/main/control/stochsys.py#L310-L517) function in python-control."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "993601a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disturbance and initial condition model\n",
    "# Note: multiply by the sampling time since we discretized the dynamics\n",
    "Rv = np.diag([0.1, 0.01]) * Ts\n",
    "# Rv = np.diag([10, 1]) * Ts       # Variant: no input information\n",
    "P0 = np.diag([1, 1, 0.1])\n",
    "\n",
    "# Combine the sensors\n",
    "# Note: no scaling by the sampling time here, since the sensor noise is already discrete time\n",
    "C = np.vstack([C_lon, C_lat])\n",
    "Rw = sp.linalg.block_diag(Rw_lon, Rw_lat)\n",
    "\n",
    "estim = ct.create_estimator_iosystem(discsys, Rv, Rw, C=C, P0=P0)\n",
    "print(estim)"
   ]
  },
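  {
   "cell_type": "markdown",
   "id": "kc-sketch-05-md",
   "metadata": {},
   "source": [
    "For intuition about what the estimator computes, the covariance recursion of the discrete-time Kalman filter in one-step predictor form can be iterated by hand: $L[k] = A P[k] C^T (R_w + C P[k] C^T)^{-1}$ and $P[k+1] = (A - L[k] C) P[k] (A - L[k] C)^T + F R_v F^T + L[k] R_w L[k]^T$. We state this as the standard predictor-form recursion and assume, without having checked the python-control source, that it matches the Theorem 7.2 form used by the estimator. With constant matrices the recursion converges to the stabilizing solution of a discrete algebraic Riccati equation, which the sketch checks with `scipy.linalg.solve_discrete_are` applied to the dual problem. The model is a hand-derived linearization with $v = 10$ m/s and a hypothetical wheelbase of 3 m, `kalman_predictor_step` is a hypothetical helper, and the noise and initial covariances use the same values as the cells above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "kc-sketch-05-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.linalg import block_diag, solve_discrete_are\n",
    "\n",
    "# Hand-derived Euler model; v = 10 m/s, Ts = 0.1 s, hypothetical wheelbase 3 m\n",
    "kf_Ts, kf_v, kf_l = 0.1, 10.0, 3.0\n",
    "kf_A = np.eye(3) + np.array([[0, 0, 0], [0, 0, kf_v], [0, 0, 0]]) * kf_Ts\n",
    "kf_F = np.array([[1, 0], [0, 0], [0, kf_v / kf_l]]) * kf_Ts   # disturbances enter with the inputs\n",
    "kf_C = np.vstack([np.eye(2, 3), np.eye(2, 3)])                # both sensors measure (x, y)\n",
    "kf_Rv = np.diag([0.1, 0.01]) * kf_Ts\n",
    "kf_Rw = block_diag(np.diag([0.1 ** 2, 1 ** 2]), np.diag([1 ** 2, 0.1 ** 2]))\n",
    "\n",
    "def kalman_predictor_step(P, A, C, F, Rv, Rw):\n",
    "    # One covariance update of the one-step predictor form; returns P[k+1], L[k]\n",
    "    L = A @ P @ C.T @ np.linalg.inv(Rw + C @ P @ C.T)\n",
    "    AmLC = A - L @ C\n",
    "    return AmLC @ P @ AmLC.T + F @ Rv @ F.T + L @ Rw @ L.T, L\n",
    "\n",
    "kf_P = np.diag([1.0, 1.0, 0.1])      # same values as P0\n",
    "for kf_k in range(2000):\n",
    "    kf_P, kf_L = kalman_predictor_step(kf_P, kf_A, kf_C, kf_F, kf_Rv, kf_Rw)\n",
    "\n",
    "# Steady state: filter Riccati equation, solved via the dual (control) form\n",
    "kf_Pss = solve_discrete_are(kf_A.T, kf_C.T, kf_F @ kf_Rv @ kf_F.T, kf_Rw)\n",
    "print('steady-state gain L:')\n",
    "print(kf_L)\n",
    "print('max |P - P_dare|:', np.abs(kf_P - kf_Pss).max())"
   ]
  },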
  {
   "cell_type": "markdown",
   "id": "0c2e8ab0",
   "metadata": {},
   "source": [
    "We can now run the estimator on the noisy signals to see how well it works."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d02ec33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the inputs to the estimator\n",
    "Y = np.vstack([xd[0:2] + W_lon, xd[0:2] + W_lat])\n",
    "U = np.vstack([Y, ud])        # add input to the Kalman filter\n",
    "# U = np.vstack([Y, ud * 0])  # variant: no input information\n",
    "X0 = np.hstack([xd[:, 0], P0.reshape(-1)])\n",
    "\n",
    "# Run the estimator on the trajectory\n",
    "estim_resp = ct.input_output_response(estim, timepts, U, X0)\n",
    "\n",
    "# Run a prediction to see what happens next\n",
    "T_predict = np.arange(timepts[-1], timepts[-1] + 4 + Ts, Ts)\n",
    "U_predict = np.outer(U[:, -1], np.ones_like(T_predict))\n",
    "predict_resp = ct.input_output_response(\n",
    "    estim, T_predict, U_predict, estim_resp.states[:, -1],\n",
    "    params={'correct': False})\n",
    "\n",
    "# Plot the estimated trajectory versus the actual trajectory\n",
    "# (error bars show +/- one standard deviation, the square root of the variance in P)\n",
    "plt.subplot(2, 1, 1)\n",
    "plt.errorbar(\n",
    "    estim_resp.time, estim_resp.outputs[0], \n",
    "    np.sqrt(estim_resp.states[estim.find_state('P[0,0]')]), fmt='b-', **ebarstyle)\n",
    "plt.errorbar(\n",
    "    predict_resp.time, predict_resp.outputs[0], \n",
    "    np.sqrt(predict_resp.states[estim.find_state('P[0,0]')]), fmt='r-', **ebarstyle)\n",
    "plt.plot(timepts, xd[0], 'k--')\n",
    "plt.ylabel(\"$x$ position [m]\")\n",
    "\n",
    "plt.subplot(2, 1, 2)\n",
    "plt.errorbar(\n",
    "    estim_resp.time, estim_resp.outputs[1], \n",
    "    np.sqrt(estim_resp.states[estim.find_state('P[1,1]')]), fmt='b-', **ebarstyle)\n",
    "plt.errorbar(\n",
    "    predict_resp.time, predict_resp.outputs[1], \n",
    "    np.sqrt(predict_resp.states[estim.find_state('P[1,1]')]), fmt='r-', **ebarstyle)\n",
    "# lims = plt.axis(); plt.axis([lims[0], lims[1], -5, 5])\n",
    "plt.plot(timepts, xd[1], 'k--');\n",
    "plt.ylabel(\"$y$ position [m]\")\n",
    "plt.xlabel(\"Time $t$ [s]\");"
   ]
  },
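  {
   "cell_type": "markdown",
   "id": "kc-sketch-06-md",
   "metadata": {},
   "source": [
    "In the prediction run above, passing `params={'correct': False}` disables the measurement update, so the covariance evolves only through the time update $P[k+1] = A P[k] A^T + F R_v F^T$ and the error bars widen. A minimal sketch of that growth, using a hand-derived Euler model of the linearized dynamics with $v = 10$ m/s and a hypothetical wheelbase of 3 m, and an illustrative starting covariance rather than the one reached at the end of the run: the $x$ variance grows by the constant amount $(F R_v F^T)_{11}$ each step, while the $y$ variance grows faster because heading uncertainty feeds into $y$ through the $v T_s$ entry of $A$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "kc-sketch-06-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hand-derived Euler model; v = 10 m/s, Ts = 0.1 s, hypothetical wheelbase 3 m\n",
    "pred_Ts, pred_v, pred_l = 0.1, 10.0, 3.0\n",
    "pred_A = np.eye(3) + np.array([[0, 0, 0], [0, 0, pred_v], [0, 0, 0]]) * pred_Ts\n",
    "pred_F = np.array([[1, 0], [0, 0], [0, pred_v / pred_l]]) * pred_Ts\n",
    "pred_Rv = np.diag([0.1, 0.01]) * pred_Ts\n",
    "\n",
    "# Time update only (no measurement correction), from an illustrative starting covariance\n",
    "pred_hist = [np.diag([0.01, 0.01, 0.001])]\n",
    "for pred_k in range(40):              # 4 s of prediction at Ts = 0.1 s\n",
    "    pred_Pk = pred_hist[-1]\n",
    "    pred_hist.append(pred_A @ pred_Pk @ pred_A.T + pred_F @ pred_Rv @ pred_F.T)\n",
    "pred_P = np.array(pred_hist)          # shape (41, 3, 3)\n",
    "\n",
    "print('var(x): %.4g -> %.4g' % (pred_P[0, 0, 0], pred_P[-1, 0, 0]))\n",
    "print('var(y): %.4g -> %.4g' % (pred_P[0, 1, 1], pred_P[-1, 1, 1]))"
   ]
  },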
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44f69f79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the estimation errors (error bars: +/- one standard deviation)\n",
    "# After the final time the car continues straight at constant speed, so the\n",
    "# true x position during the prediction keeps increasing at ud[0, -1] m/s\n",
    "x_true_predict = xd[0, -1] + ud[0, -1] * (predict_resp.time - timepts[-1])\n",
    "\n",
    "plt.subplot(2, 1, 1)\n",
    "plt.errorbar(\n",
    "    estim_resp.time, estim_resp.outputs[0] - xd[0], \n",
    "    np.sqrt(estim_resp.states[estim.find_state('P[0,0]')]), fmt='b-', **ebarstyle)\n",
    "plt.errorbar(\n",
    "    predict_resp.time, predict_resp.outputs[0] - x_true_predict, \n",
    "    np.sqrt(predict_resp.states[estim.find_state('P[0,0]')]), fmt='r-', **ebarstyle)\n",
    "lims = plt.axis(); plt.axis([lims[0], lims[1], -0.2, 0.2])\n",
    "# lims = plt.axis(); plt.axis([lims[0], lims[1], -2, 0.2])\n",
    "plt.ylabel(\"$x$ error [m]\")\n",
    "\n",
    "plt.subplot(2, 1, 2)\n",
    "plt.errorbar(\n",
    "    estim_resp.time, estim_resp.outputs[1] - xd[1], \n",
    "    np.sqrt(estim_resp.states[estim.find_state('P[1,1]')]), fmt='b-', **ebarstyle)\n",
    "plt.errorbar(\n",
    "    predict_resp.time, predict_resp.outputs[1] - xd[1, -1], \n",
    "    np.sqrt(predict_resp.states[estim.find_state('P[1,1]')]), fmt='r-', **ebarstyle)\n",
    "lims = plt.axis(); plt.axis([lims[0], lims[1], -0.2, 0.2])\n",
    "plt.ylabel(\"$y$ error [m]\")\n",
    "plt.xlabel(\"Time $t$ [s]\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f6c1b6f",
   "metadata": {},
   "source": [
    "## Things to try\n",
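    "* Remove the input information (see the commented-out variants for `Rv` and `U`) and update `P0` and `Rv` accordingly\n",
    "* Change the sampling rate (for example, `Ts = 0.5`)"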
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f680b92",
   "metadata": {},
   "source": [
    "## Predictor-corrector form\n",
    "\n",
    "Instead of using `create_estimator_iosystem`, we can also compute the estimate manually, here using the predictor-corrector form."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa488d51",
   "metadata": {},
   "outputs": [],
   "source": [
    "# System matrices (disturbances are assumed to enter through the inputs, so F = B)\n",
    "A, B, F = discsys.A, discsys.B, discsys.B\n",
    "\n",
    "# Create an array to store the results\n",
    "xhat = np.zeros((discsys.nstates, timepts.size))\n",
    "P = np.zeros((discsys.nstates, discsys.nstates, timepts.size))\n",
    "\n",
    "# Update the estimates at each time\n",
    "for i, t in enumerate(timepts):\n",
    "    # Prediction step\n",
    "    if i == 0:\n",
    "        # Use the initial condition\n",
    "        xkkm1 = xd[:, 0]\n",
    "        Pkkm1 = P0\n",
    "    else:\n",
    "        xkkm1 = A @ xkk + B @ ud[:, i-1]\n",
    "        Pkkm1 = A @ Pkk @ A.T + F @ Rv @ F.T\n",
    "    \n",
    "    # Correction step (variant: apply only when sensor data is available)\n",
    "    L = Pkkm1 @ C.T @ np.linalg.inv(Rw + C @ Pkkm1 @ C.T)\n",
    "    xkk = xkkm1 - L @ (C @ xkkm1 - Y[:, i])\n",
    "    Pkk = Pkkm1 - L @ C @ Pkkm1\n",
    "\n",
    "    # Save the state estimate and covariance for later plotting\n",
    "    xhat[:, i], P[:, :, i] = xkkm1, Pkkm1  # For comparison to Kalman form\n",
    "    # xhat[:, i], P[:, :, i] = xkk, Pkk    # variant: save the corrected estimate x[k|k]\n",
    "    \n",
    "# Plot the estimates with +/- one standard deviation error bars\n",
    "plt.subplot(2, 1, 1)\n",
    "plt.errorbar(timepts, xhat[0], np.sqrt(P[0, 0]), fmt='b-', **ebarstyle)\n",
    "plt.plot(timepts, xd[0], 'k--')\n",
    "plt.ylabel(\"$x$ position [m]\")\n",
    "\n",
    "plt.subplot(2, 1, 2)\n",
    "plt.errorbar(timepts, xhat[1], np.sqrt(P[1, 1]), fmt='b-', **ebarstyle)\n",
    "plt.plot(timepts, xd[1], 'k--')\n",
    "plt.ylabel(\"$y$ position [m]\")\n",
    "plt.xlabel(\"Time $t$ [s]\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4eda4729",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the estimation errors (and compare to Kalman form)\n",
    "plt.subplot(2, 1, 1)\n",
    "plt.errorbar(timepts, xhat[0] - xd[0], np.sqrt(P[0, 0]), fmt='b-', **ebarstyle)\n",
    "plt.plot(estim_resp.time, estim_resp.outputs[0] - xd[0], 'r--', linewidth=3)\n",
    "lims = plt.axis(); plt.axis([lims[0], lims[1], -0.2, 0.2])\n",
    "plt.ylabel(\"x error [m]\")\n",
    "\n",
    "plt.subplot(2, 1, 2)\n",
    "plt.errorbar(timepts, xhat[1] - xd[1], np.sqrt(P[1, 1]), fmt='b-', **ebarstyle,\n",
    "             label='predictor/corrector')\n",
    "plt.plot(estim_resp.time, estim_resp.outputs[1] - xd[1], 'r--', linewidth=3,\n",
    "         label='Kalman form')\n",
    "lims = plt.axis(); plt.axis([lims[0], lims[1], -0.2, 0.2])\n",
    "plt.ylabel(\"y error [m]\")\n",
    "plt.xlabel(\"Time $t$ [s]\")\n",
    "plt.legend(loc='lower right');"
   ]
  },
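  {
   "cell_type": "markdown",
   "id": "kc-sketch-07-md",
   "metadata": {},
   "source": [
    "The two curves above should agree because the two forms describe the same filter. Substituting the correction step into the next prediction gives $\\hat x[k+1|k] = A \\hat x[k|k-1] + B u[k] + A L_c (y[k] - C \\hat x[k|k-1])$, where $L_c = P[k|k-1] C^T (R_w + C P[k|k-1] C^T)^{-1}$ is the corrector gain, so the one-step predictor gain is $L = A L_c$. The sketch below checks one step numerically on random test matrices (arbitrary values, not the vehicle model), with the predictor-form covariance written in Joseph form, $(A - LC) P (A - LC)^T + F R_v F^T + L R_w L^T$; we assume, without having checked the source, that this is the form used by the Kalman-form estimator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "kc-sketch-07-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "pc_rng = np.random.default_rng(1)\n",
    "pc_n, pc_m, pc_p = 3, 2, 4                 # states, inputs, outputs (arbitrary)\n",
    "pc_A = pc_rng.standard_normal((pc_n, pc_n))\n",
    "pc_B = pc_rng.standard_normal((pc_n, pc_m))\n",
    "pc_F = pc_B                                # disturbances enter through the inputs\n",
    "pc_C = pc_rng.standard_normal((pc_p, pc_n))\n",
    "pc_Rv, pc_Rw = np.diag([0.2, 0.05]), 0.1 * np.eye(pc_p)\n",
    "pc_G = pc_rng.standard_normal((pc_n, pc_n))\n",
    "pc_P = pc_G @ pc_G.T + np.eye(pc_n)        # P[k|k-1], symmetric positive definite\n",
    "pc_x = pc_rng.standard_normal(pc_n)        # xhat[k|k-1]\n",
    "pc_u = pc_rng.standard_normal(pc_m)        # u[k]\n",
    "pc_y = pc_rng.standard_normal(pc_p)        # y[k]\n",
    "pc_S = pc_Rw + pc_C @ pc_P @ pc_C.T        # innovation covariance\n",
    "\n",
    "# Predictor-corrector: correct with y[k], then predict to k+1\n",
    "pc_Lc = pc_P @ pc_C.T @ np.linalg.inv(pc_S)\n",
    "pc_xkk = pc_x - pc_Lc @ (pc_C @ pc_x - pc_y)\n",
    "pc_Pkk = pc_P - pc_Lc @ pc_C @ pc_P\n",
    "pc_x_pc = pc_A @ pc_xkk + pc_B @ pc_u\n",
    "pc_P_pc = pc_A @ pc_Pkk @ pc_A.T + pc_F @ pc_Rv @ pc_F.T\n",
    "\n",
    "# One-step predictor form with gain L = A Lc, covariance in Joseph form\n",
    "pc_L = pc_A @ pc_P @ pc_C.T @ np.linalg.inv(pc_S)\n",
    "pc_x_kf = pc_A @ pc_x + pc_B @ pc_u + pc_L @ (pc_y - pc_C @ pc_x)\n",
    "pc_AmLC = pc_A - pc_L @ pc_C\n",
    "pc_P_kf = pc_AmLC @ pc_P @ pc_AmLC.T + pc_F @ pc_Rv @ pc_F.T + pc_L @ pc_Rw @ pc_L.T\n",
    "\n",
    "print('max state difference:', np.abs(pc_x_pc - pc_x_kf).max())\n",
    "print('max covariance difference:', np.abs(pc_P_pc - pc_P_kf).max())"
   ]
  },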
  {
   "cell_type": "markdown",
   "id": "19a673a1",
   "metadata": {},
   "source": [
    "## Information filter\n",
    "\n",
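    "An alternative way to implement the computation is the information filter formulation. In the correction step it works with the inverse of the covariance (the information matrix), so the contribution of each sensor can be added separately."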
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36111bc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from numpy.linalg import inv\n",
    "\n",
    "# Update the estimates at each time\n",
    "for i, t in enumerate(timepts):\n",
    "    # Prediction step\n",
    "    if i == 0:\n",
    "        # Use the initial condition\n",
    "        xkkm1 = xd[:, 0]\n",
    "        Pkkm1 = P0\n",
    "    else:\n",
    "        xkkm1 = A @ xkk + B @ ud[:, i-1]\n",
    "        Pkkm1 = A @ Pkk @ A.T + F @ Rv @ F.T\n",
    "        \n",
    "    # Correction step (variant: apply only when sensor data is available)\n",
    "    Ikk, Zkk = inv(Pkkm1), inv(Pkkm1) @ xkkm1\n",
    "    \n",
    "    # Longitudinal sensor update\n",
    "    Ikk += C_lon.T @ inv(Rw_lon) @ C_lon     # Omega_lon\n",
    "    Zkk += C_lon.T @ inv(Rw_lon) @ Y[:2, i]  # Psi_lon\n",
    "\n",
    "    # Lateral sensor update\n",
    "    Ikk += C_lat.T @ inv(Rw_lat) @ C_lat     # Omega_lat\n",
    "    Zkk += C_lat.T @ inv(Rw_lat) @ Y[2:, i]  # Psi_lat\n",
    "    \n",
    "    # Compute the updated state and covariance \n",
    "    Pkk = inv(Ikk)\n",
    "    xkk = Pkk @ Zkk\n",
    "\n",
    "    # Save the state estimate and covariance for later plotting\n",
    "    xhat[:, i], P[:, :, i] = xkkm1, Pkkm1  # for comparison to Kalman form\n",
    "\n",
    "# Plot the estimation errors (and compare to Kalman form)\n",
    "plt.subplot(2, 1, 1)\n",
    "plt.errorbar(timepts, xhat[0] - xd[0], np.sqrt(P[0, 0]), fmt='b-', **ebarstyle)\n",
    "plt.plot(estim_resp.time, estim_resp.outputs[0] - xd[0], 'r--', linewidth=3)\n",
    "lims = plt.axis(); plt.axis([lims[0], lims[1], -0.2, 0.2])\n",
    "plt.ylabel(\"x error [m]\")\n",
    "\n",
    "plt.subplot(2, 1, 2)\n",
    "plt.errorbar(timepts, xhat[1] - xd[1], np.sqrt(P[1, 1]), fmt='b-', **ebarstyle,\n",
    "             label='information filter')\n",
    "plt.plot(estim_resp.time, estim_resp.outputs[1] - xd[1], 'r--', linewidth=3,\n",
    "         label='Kalman form')\n",
    "lims = plt.axis(); plt.axis([lims[0], lims[1], -0.2, 0.2])\n",
    "plt.ylabel(\"y error [m]\")\n",
    "plt.xlabel(\"Time $t$ [s]\")\n",
    "plt.legend(loc='lower right');"
   ]
  },
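  {
   "cell_type": "markdown",
   "id": "kc-sketch-08-md",
   "metadata": {},
   "source": [
    "Processing the two sensors one at a time works because, with independent measurement noise, each sensor adds its own term $C_i^T R_{w,i}^{-1} C_i$ to the information matrix and $C_i^T R_{w,i}^{-1} y_i$ to the information vector. A minimal check on random test data (not the notebook's signals) that this produces the same corrected estimate and covariance as a single covariance-form correction with the stacked $C$ and block-diagonal $R_w$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "kc-sketch-08-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.linalg import block_diag\n",
    "\n",
    "info_rng = np.random.default_rng(2)\n",
    "info_G = info_rng.standard_normal((3, 3))\n",
    "info_P = info_G @ info_G.T + np.eye(3)          # prior covariance P[k|k-1]\n",
    "info_x = info_rng.standard_normal(3)            # prior estimate xhat[k|k-1]\n",
    "info_C1, info_R1 = np.eye(2, 3), np.diag([0.1 ** 2, 1 ** 2])   # longitudinal-style sensor\n",
    "info_C2, info_R2 = np.eye(2, 3), np.diag([1 ** 2, 0.1 ** 2])   # lateral-style sensor\n",
    "info_y1, info_y2 = info_rng.standard_normal(2), info_rng.standard_normal(2)\n",
    "\n",
    "# Information form: start from the prior, then add one term per sensor\n",
    "info_Omega = np.linalg.inv(info_P)\n",
    "info_psi = info_Omega @ info_x\n",
    "for info_Ci, info_Ri, info_yi in ((info_C1, info_R1, info_y1),\n",
    "                                  (info_C2, info_R2, info_y2)):\n",
    "    info_Omega = info_Omega + info_Ci.T @ np.linalg.inv(info_Ri) @ info_Ci\n",
    "    info_psi = info_psi + info_Ci.T @ np.linalg.inv(info_Ri) @ info_yi\n",
    "info_P_if = np.linalg.inv(info_Omega)\n",
    "info_x_if = info_P_if @ info_psi\n",
    "\n",
    "# Covariance form: one correction with the stacked sensors\n",
    "info_C = np.vstack([info_C1, info_C2])\n",
    "info_R = block_diag(info_R1, info_R2)\n",
    "info_y = np.concatenate([info_y1, info_y2])\n",
    "info_L = info_P @ info_C.T @ np.linalg.inv(info_R + info_C @ info_P @ info_C.T)\n",
    "info_x_cf = info_x - info_L @ (info_C @ info_x - info_y)\n",
    "info_P_cf = info_P - info_L @ info_C @ info_P\n",
    "\n",
    "print('max state difference:', np.abs(info_x_if - info_x_cf).max())\n",
    "print('max covariance difference:', np.abs(info_P_if - info_P_cf).max())"
   ]
  },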
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad5cf57f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
